[Pipelines] Implement Z-Image ModuleV2 pipeline#21
Conversation
Port Z-Image to the Graph API / ModuleV2 runtime using V2 text encoder, transformer, and VAE components. Restore the current ModuleV3 feature set and behavior in the V2 path, including: - Z-Image transformer/model/config/weight adapter wiring - diffusion pipeline, arch registration, and ModuleV2/ModuleV3 selection via --prefer-module-v3 - batched CFG, CFG renormalization, and image parity with ModuleV3 - transformer-side RoPE micro-optimizations: - single unified RoPE embedder call - interleaved [cos, sin] frequency generation - rope_ragged_with_position_ids hot path - preamble dtype cast and direct modulation slicing stack-info: PR: #21, branch: byungchul-sqzb/stack/2
c58e9d4 to
3bb144c
Compare
f76185d to
23b1c55
Compare
There was a problem hiding this comment.
Code Review
This pull request implements the Z-Image diffusion architecture, providing both standard and ModuleV3 versions. The changes include the core DiT model, attention mechanisms with rotary embeddings, and the generation pipeline. Review feedback identifies a critical shape mismatch in the attention layer's position IDs, a logic error in the CFG renormalization process where the target norm is incorrectly calculated, and a performance bottleneck caused by frequent host-to-device buffer transfers within the execution loop.
| position_ids = ops.range( | ||
| 0, seq_len, dtype=DType.uint32, device=query.device | ||
| ) | ||
| position_ids = ops.broadcast_to( | ||
| ops.unsqueeze(position_ids, 0), [batch_size, seq_len] | ||
| ) |
There was a problem hiding this comment.
The position_ids tensor must be flattened to 1D to match the first dimension of the ragged input tensors (query_ragged, key_ragged) passed to the rope_ragged_with_position_ids kernel. Currently, it is a 2D tensor of shape [batch_size, seq_len], which will cause a shape mismatch or incorrect indexing in the kernel.
| position_ids = ops.range( | |
| 0, seq_len, dtype=DType.uint32, device=query.device | |
| ) | |
| position_ids = ops.broadcast_to( | |
| ops.unsqueeze(position_ids, 0), [batch_size, seq_len] | |
| ) | |
| position_ids = ops.range( | |
| 0, seq_len, dtype=DType.uint32, device=query.device | |
| ) | |
| position_ids = ops.broadcast_to( | |
| ops.unsqueeze(position_ids, 0), [batch_size, seq_len] | |
| ) | |
| position_ids = ops.reshape(position_ids, [batch_size * seq_len]) |
| with Tracer("transformer"): | ||
| noise_pred = self.transformer( | ||
| latents, | ||
| prompt_embeds, | ||
| timestep, | ||
| img_ids, | ||
| txt_ids, | ||
| )[0] | ||
| assert negative_prompt_embeds is not None | ||
| with Tracer("cfg_transformer"): | ||
| neg_noise_pred = self.transformer( | ||
| latents, | ||
| negative_prompt_embeds, | ||
| timestep, | ||
| neg_img_ids, | ||
| neg_txt_ids, | ||
| )[0] | ||
| assert guidance_buf is not None | ||
| noise_pred = self._cfg_combine( | ||
| noise_pred, neg_noise_pred, guidance_buf | ||
| ) | ||
| if model_inputs.cfg_normalization: | ||
| noise_pred = self._cfg_renormalization( | ||
| noise_pred, | ||
| noise_pred, | ||
| ) | ||
| else: |
There was a problem hiding this comment.
In the non-batched CFG path (used when explicit_negative_prompt is True), the renormalization logic is incorrect. It passes the CFG result (noise_pred) as both the pos and pred arguments to _cfg_renormalization. This makes the renormalization a no-op because it uses the norm of the CFG result as the target norm. It should instead use the norm of the unconditioned (positive) prediction.
elif apply_cfg:
with Tracer("transformer"):
pos_noise_pred = self.transformer(
latents,
prompt_embeds,
timestep,
img_ids,
txt_ids,
)[0]
assert negative_prompt_embeds is not None
with Tracer("cfg_transformer"):
neg_noise_pred = self.transformer(
latents,
negative_prompt_embeds,
timestep,
neg_img_ids,
neg_txt_ids,
)[0]
assert guidance_buf is not None
noise_pred = self._cfg_combine(
pos_noise_pred, neg_noise_pred, guidance_buf
)
if model_inputs.cfg_normalization:
noise_pred = self._cfg_renormalization(
pos_noise_pred,
noise_pred,
)| cfg_timestep_bufs = [ | ||
| Buffer.from_dlpack( | ||
| np.full((2 * batch_size,), float(t), dtype=np.float32) | ||
| ).to(device) | ||
| for t in transformed | ||
| ] |
There was a problem hiding this comment.
Creating a list of buffers on the host and uploading them to the device inside the execute method is inefficient, as it incurs host-side allocation and device transfer overhead for every denoising step. These buffers should be pre-allocated and uploaded in prepare_inputs, or ideally, generated on-device using graph operations (e.g., slicing and broadcasting the existing all_timesteps buffer).
[Pipelines] Implement Z-Image ModuleV2 pipeline
Port Z-Image to the Graph API / ModuleV2 runtime using V2 text encoder,
transformer, and VAE components.
Restore the current ModuleV3 feature set and behavior in the V2 path,
including:
--prefer-module-v3